{
  "nbformat": 4,
  "nbformat_minor": 2,
  "metadata": {
    "colab": {
      "name": "Copy of Copy of torchaudio_MVDR_tutorial.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3.9.6 64-bit ('dev': conda)"
    },
    "language_info": {
      "name": "python",
      "version": "3.9.6",
      "mimetype": "text/x-python",
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "pygments_lexer": "ipython3",
      "nbconvert_exporter": "python",
      "file_extension": ".py"
    },
    "interpreter": {
      "hash": "6a702c257b9a40163843ba760790c17a6ddd2abeef8febce55475eea4b92c28c"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "<a href=\"https://colab.research.google.com/github/nateanl/audio/blob/mvdr/examples/beamforming/MVDR_tutorial.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ],
      "metadata": {
        "id": "xheYDPUcYGbp"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "This is a tutorial on how to apply MVDR beamforming by using [torchaudio](https://github.com/pytorch/audio)\n",
        "-----------\n",
        "\n",
        "The multi-channel audio example is selected from [ConferencingSpeech](https://github.com/ConferencingSpeech/ConferencingSpeech2021) dataset. \n",
        "\n",
        "```\n",
        "original filename: SSB07200001\\#noise-sound-bible-0038\\#7.86_6.16_3.00_3.14_4.84_134.5285_191.7899_0.4735\\#15217\\#25.16333303751458\\#0.2101221178590021.wav\n",
        "```\n",
        "\n",
        "Note:\n",
        "- You need to use the nightly torchaudio in order to use the MVDR and InverseSpectrogram modules.\n",
        "\n",
        "\n",
        "Steps\n",
        "\n",
        "- Ideal Ratio Mask (IRM) is generated by dividing the clean/noise magnitude by the mixture magnitude.\n",
        "- We test all three solutions (``ref_channel``, ``stv_evd``, ``stv_power``) of torchaudio's MVDR module.\n",
        "- We test the single-channel and multi-channel masks for MVDR beamforming. The multi-channel mask is averaged along channel dimension when computing the covariance matrices of speech and noise, respectively."
      ],
      "metadata": {
        "id": "L6R0MXe5Wr19"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "!pip install --pre torchaudio -f https://download.pytorch.org/whl/nightly/torch_nightly.html --force"
      ],
      "outputs": [],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "juO6PE9XLctD",
        "outputId": "8777ba14-da99-4c18-d80f-b070ad9861af"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "import torch\n",
        "import torchaudio\n",
        "import IPython.display as ipd"
      ],
      "outputs": [],
      "metadata": {
        "id": "T4u4unhFMMBG"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Load audios of mixture, reverberated clean speech, and dry clean speech."
      ],
      "metadata": {
        "id": "bDILVXkeg2s3"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "!curl -LJO https://github.com/nateanl/torchaudio_mvdr_tutorial/raw/main/wavs/mix.wav\n",
        "!curl -LJO https://github.com/nateanl/torchaudio_mvdr_tutorial/raw/main/wavs/reverb_clean.wav\n",
        "!curl -LJO https://github.com/nateanl/torchaudio_mvdr_tutorial/raw/main/wavs/clean.wav"
      ],
      "outputs": [],
      "metadata": {
        "id": "2XIyMa_VKv0c",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "404f46a6-e70c-4f80-af8d-d356408a9f18"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "mix, sr = torchaudio.load('mix.wav')\n",
        "reverb_clean, sr2 = torchaudio.load('reverb_clean.wav')\n",
        "clean, sr3 = torchaudio.load('clean.wav')\n",
        "assert sr == sr2\n",
        "noise = mix - reverb_clean"
      ],
      "outputs": [],
      "metadata": {
        "id": "iErB6UhQPtD3"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Note: The MVDR Module requires ``torch.cdouble`` dtype for noisy STFT. We need to convert the dtype of the waveforms to ``torch.double``"
      ],
      "metadata": {
        "id": "Aq-x_fo5VkwL"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "mix = mix.to(torch.double)\n",
        "noise = noise.to(torch.double)\n",
        "clean = clean.to(torch.double)\n",
        "reverb_clean = reverb_clean.to(torch.double)"
      ],
      "outputs": [],
      "metadata": {
        "id": "5c66pHcQV0P9"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Initilize the Spectrogram and InverseSpectrogram modules"
      ],
      "metadata": {
        "id": "05D26we0V4P-"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "stft = torchaudio.transforms.Spectrogram(n_fft=1024, hop_length=256, return_complex=True, power=None)\n",
        "istft = torchaudio.transforms.InverseSpectrogram(n_fft=1024, hop_length=256)"
      ],
      "outputs": [],
      "metadata": {
        "id": "NcGhD7_TUKd1"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Compute the complex-valued STFT of mixture, clean speech, and noise"
      ],
      "metadata": {
        "id": "-dlJcuSNUCgA"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "spec_mix = stft(mix)\n",
        "spec_clean = stft(clean)\n",
        "spec_reverb_clean = stft(reverb_clean)\n",
        "spec_noise = stft(noise)"
      ],
      "outputs": [],
      "metadata": {
        "id": "w1vO7w1BUKt4"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Generate the Ideal Ratio Mask (IRM)\n",
        "Note: we found using the mask directly peforms better than using the square root of it. This is slightly different from the definition of IRM."
      ],
      "metadata": {
        "id": "8SBchrDhURK1"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "def get_irms(spec_clean, spec_noise, spec_mix):\n",
        "    mag_mix = spec_mix.abs() ** 2\n",
        "    mag_clean = spec_clean.abs() ** 2\n",
        "    mag_noise = spec_noise.abs() ** 2\n",
        "    irm_speech = mag_clean / (mag_clean + mag_noise)\n",
        "    irm_noise = mag_noise / (mag_clean + mag_noise)\n",
        "\n",
        "    return irm_speech, irm_noise"
      ],
      "outputs": [],
      "metadata": {
        "id": "2gB63BoWUmHZ"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Note: We use reverberant clean speech as the target here, you can also set it to dry clean speech"
      ],
      "metadata": {
        "id": "reGMDyNCaE7L"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "irm_speech, irm_noise = get_irms(spec_reverb_clean, spec_noise, spec_mix)"
      ],
      "outputs": [],
      "metadata": {
        "id": "HSTCGy_5Uqzx"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Apply MVDR beamforming by using multi-channel masks"
      ],
      "metadata": {
        "id": "1R5I_TmSUbS0"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "results_multi = {}\n",
        "for solution in ['ref_channel', 'stv_evd', 'stv_power']:\n",
        "    mvdr = torchaudio.transforms.MVDR(ref_channel=0, solution=solution, multi_mask=True)\n",
        "    stft_est = mvdr(spec_mix, irm_speech, irm_noise)\n",
        "    est = istft(stft_est, length=mix.shape[-1])\n",
        "    results_multi[solution] = est"
      ],
      "outputs": [],
      "metadata": {
        "id": "SiWFZgCbadz7"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Apply MVDR beamforming by using single-channel masks \n",
        "(We use the 1st channel as an example. The channel selection may depend on the design of the microphone array)"
      ],
      "metadata": {
        "id": "Ukez6_lcUfna"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "results_single = {}\n",
        "for solution in ['ref_channel', 'stv_evd', 'stv_power']:\n",
        "    mvdr = torchaudio.transforms.MVDR(ref_channel=0, solution=solution, multi_mask=False)\n",
        "    stft_est = mvdr(spec_mix, irm_speech[0], irm_noise[0])\n",
        "    est = istft(stft_est, length=mix.shape[-1])\n",
        "    results_single[solution] = est"
      ],
      "outputs": [],
      "metadata": {
        "id": "kLeNKsk-VLm5"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Compute Si-SDR scores"
      ],
      "metadata": {
        "id": "uJjJNdYiUnf0"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "def si_sdr(estimate, reference, epsilon=1e-8):\n",
        "    estimate = estimate - estimate.mean()\n",
        "    reference = reference - reference.mean()\n",
        "    reference_pow = reference.pow(2).mean(axis=1, keepdim=True)\n",
        "    mix_pow = (estimate * reference).mean(axis=1, keepdim=True)\n",
        "    scale = mix_pow / (reference_pow + epsilon)\n",
        "\n",
        "    reference = scale * reference\n",
        "    error = estimate - reference\n",
        "\n",
        "    reference_pow = reference.pow(2)\n",
        "    error_pow = error.pow(2)\n",
        "\n",
        "    reference_pow = reference_pow.mean(axis=1)\n",
        "    error_pow = error_pow.mean(axis=1)\n",
        "\n",
        "    sisdr = 10 * torch.log10(reference_pow) - 10 * torch.log10(error_pow)\n",
        "    return sisdr.item()"
      ],
      "outputs": [],
      "metadata": {
        "id": "MgmAJcyiU-FU"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Single-channel mask results"
      ],
      "metadata": {
        "id": "3TCJEwTOUxci"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "for solution in results_single:\n",
        "    print(solution+\": \", si_sdr(results_single[solution][None,...], reverb_clean[0:1]))"
      ],
      "outputs": [],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "NrUXXj98VVY7",
        "outputId": "bc113347-70e3-47a9-8479-8aeeeca80abf"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Multi-channel mask results"
      ],
      "metadata": {
        "id": "-7AnjM-gU3c8"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "for solution in results_multi:\n",
        "    print(solution+\": \", si_sdr(results_multi[solution][None,...], reverb_clean[0:1]))"
      ],
      "outputs": [],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "S_VINTnlXobM",
        "outputId": "234b5615-63e7-44d8-f816-a6cc05999e52"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Display the mixture audio"
      ],
      "metadata": {
        "id": "_vOK8vgmU_UP"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "print(\"Mixture speech\")\n",
        "ipd.Audio(mix[0], rate=16000)"
      ],
      "outputs": [],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 92
        },
        "id": "QaKauQIHYctE",
        "outputId": "674c7f9b-62a3-4298-81ac-d3ab1ee43cd7"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Display the noise"
      ],
      "metadata": {
        "id": "R-QGGm87VFQI"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "print(\"Noise\")\n",
        "ipd.Audio(noise[0], rate=16000)"
      ],
      "outputs": [],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 92
        },
        "id": "l1WgzxIZYhlk",
        "outputId": "7b100679-b4a0-47ff-b30b-9f4cb9dca3d1"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Display the clean speech"
      ],
      "metadata": {
        "id": "P3kB-jzpVKKu"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "print(\"Clean speech\")\n",
        "ipd.Audio(clean[0], rate=16000)"
      ],
      "outputs": [],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 92
        },
        "id": "pwAWvlRAVJkT",
        "outputId": "5e173a1b-2ba8-4797-8f3a-e41cbf05ac2b"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Display the enhanced audios¶"
      ],
      "metadata": {
        "id": "RIlyzL1wVTnr"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "print(\"multi-channel mask, ref_channel solution\")\n",
        "ipd.Audio(results_multi['ref_channel'], rate=16000)"
      ],
      "outputs": [],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 92
        },
        "id": "M3YQsledVIQ5",
        "outputId": "43d9ee34-6933-401b-baf9-e4cdb7d79b63"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "print(\"multi-channel mask, stv_evd solution\")\n",
        "ipd.Audio(results_multi['stv_evd'], rate=16000)"
      ],
      "outputs": [],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 92
        },
        "id": "UhYOHLvCVWBN",
        "outputId": "761468ec-ebf9-4b31-ad71-bfa2e15fed37"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "print(\"multi-channel mask, stv_power solution\")\n",
        "ipd.Audio(results_multi['stv_power'], rate=16000)"
      ],
      "outputs": [],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 92
        },
        "id": "9dv8VDtCVXzd",
        "outputId": "1ae61ea3-d3c4-479f-faad-7439f942aac1"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "print(\"single-channel mask, ref_channel solution\")\n",
        "ipd.Audio(results_single['ref_channel'], rate=16000)"
      ],
      "outputs": [],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 92
        },
        "id": "jCFUN890VZdh",
        "outputId": "c0d2a928-5dd0-4584-b277-7838ac4a9e6b"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "print(\"single-channel mask, stv_evd solution\")\n",
        "ipd.Audio(results_single['stv_evd'], rate=16000)"
      ],
      "outputs": [],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 92
        },
        "id": "hzlzagsKVbAv",
        "outputId": "96af9e37-82ca-4544-9c08-421fe222bde4"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "print(\"single-channel mask, stv_power solution\")\n",
        "ipd.Audio(results_single['stv_power'], rate=16000)"
      ],
      "outputs": [],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 92
        },
        "id": "A4igQpTnVctG",
        "outputId": "cf968089-9274-4c1c-a1a5-32b220de0bf9"
      }
    }
  ]
}